#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
A tool for plotting Google benchmark results using matplotlib. Requires
Python3, matplotlib and gbench data in JSON format.

Does a few nice things for you:
    1. Can be used with file input (cmd line arg) or reading from stdin
    2. Groups benchmarks into multiple plots based on benchmark name.
       This is currently implemented to work well with template-based
       benchmarks; other naming schemes might require modifications.
    3. Automatically detects the need for log-scale on both axes.
    4. Displaying plots or saving them all to a folder.

Missing features:
    1. Proper support for benchmarks that use two arguments.
    2. Proper handling for all types of benchmark structures, name parsing
       in this implementation is made for template-based benches.

Usage:
    # Generate benchmark data in json format using:
    > ./my_bench --benchmark_out_format=json --benchmark_out=data.json
    # Use that data with plotter:
    > ./plot_bench_json data.json

Alternatively you can route stuff and avoid using an intermediary file:
    sh > ./my_bench --benchmark_out_format=json
         --benchmark_out=/dev/stderr 2>&1 >/dev/null
         | grep "^[{} ]" | plot_gbench_json

Maybe there is a nicer way to route it?
"""

import sys
import re
import os
import json
from argparse import ArgumentParser
from collections import defaultdict

from matplotlib import pyplot as plt


def parse_args():
    """
    Builds the command line parser and parses the process arguments.

    Returns the argparse namespace with attributes "input_file" (optional
    positional), "output_path" (optional) and "output_type" (defaults to
    "png").
    """
    # each entry: (positional names/flags, keyword options for add_argument)
    argument_specs = (
        (("input_file",), {
            "nargs": "?",
            "help": "Path to file with JSON data. If not provided data "
                    "is read from stdin",
        }),
        (("--output-path",), {
            "required": False,
            "help": "Path to a folder where the plots should be saved. "
                    " If not provided the plots are displayed in the GUI.",
        }),
        (("--output-type",), {
            "required": False,
            "default": "png",
            "help": "If saving plot files use this extension.",
        }),
    )

    parser = ArgumentParser(description=__doc__)
    for names, options in argument_specs:
        parser.add_argument(*names, **options)
    return parser.parse_args()


def convert_num(string):
    """
    Converts stuff like "100" and "3k" to numbers.

    The trailing run of non-digit characters (if any) is treated as a
    magnitude suffix, the text in front of it is parsed with float().

    Args:
        string: textual number, optionally followed by a suffix.

    Returns:
        The value as a float, multiplied by the suffix multiplier.

    Raises:
        ValueError: if the numeric part can not be parsed by float() or
            the suffix is not a known one (only "k" is supported).
    """
    # known suffixes and the factor they stand for
    multipliers = {"k": 1000}

    # raw string: "\D" inside a plain literal is an invalid escape sequence
    suffix_re = re.search(r"\D+$", string)

    if not suffix_re:
        return float(string)

    suffix = string[suffix_re.start():]
    number = float(string[:suffix_re.start()])

    if suffix not in multipliers:
        raise ValueError("Unknown number suffix: " + suffix)

    return number * multipliers[suffix]


def _sum_squared_deviations(values):
    """
    Returns the sum of squared differences between each value and the mean
    of all values. Expects a non-empty list.
    """
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values)


def is_exponential_growth(numbers):
    """
    Tries to determine if the given numbers progress more in logarithmic than
    in linear fashion. Assumes numbers increase monotonically.

    Args:
        numbers: iterable of numbers.

    Returns:
        True if the ratios of consecutive numbers vary less than their
        differences do, False otherwise. Also False when there are fewer
        than two numbers or when any number is zero or negative.
    """
    numbers = list(numbers)

    # non-positive values can not be shown on a log axis, and a zero would
    # make the ratio below divide by zero
    if any(n <= 0 for n in numbers):
        return False

    pairs = list(zip(numbers, numbers[1:]))
    if not pairs:
        return False

    diffs = [n2 - n1 for (n1, n2) in pairs]
    factors = [n2 / n1 for (n1, n2) in pairs]

    # constant diff implies linear increase, constant factor implies exp
    # which is more constant?
    return _sum_squared_deviations(factors) < _sum_squared_deviations(diffs)


def _load_data(input_file):
    """
    Loads the benchmark JSON from the given file path, or from stdin if
    no path is given.
    """
    if input_file:
        with open(input_file, encoding="utf-8") as f:
            return json.load(f)
    return json.load(sys.stdin)


def _collect_measurements(data):
    """
    Collects the measurements per benchmark name.

    The part of the benchmark name after the last "/" is used as the x
    value. Benchmarks without a "/" in their name have no x value and are
    skipped with a message on stderr.

    Returns {bench_name: [(x, y, time_unit), ...]}
    """
    benchmarks = defaultdict(list)
    for bench in data["benchmarks"]:
        name, separator, x = bench["name"].rpartition("/")
        if not separator:
            print("Skipping benchmark without an argument: %s"
                  % bench["name"], file=sys.stderr)
            continue
        benchmarks[name].append((convert_num(x), bench["real_time"],
                                bench["time_unit"]))
    return benchmarks


def _group_benchmarks(benchmarks):
    """
    Groups benchmarks on name prefix (the part before the first non-word
    character). One group will be one plot with possibly multiple lines.

    Returns {group_name: {line_name: [(x, y, time_unit), ...]}}
    """
    benchmarks_groups = defaultdict(dict)
    for name, points in benchmarks.items():
        name_split = re.split(r"\W", name, maxsplit=1)
        if len(name_split) == 2:
            group, element = name_split
            benchmarks_groups[group][element] = points
        else:
            benchmarks_groups["__all_benchmarks__"][name] = points
    return benchmarks_groups


def _validate_time_units(benchmarks_groups):
    """
    Raises ValueError if a single group (one plot) mixes time units.
    """
    for measurements in benchmarks_groups.values():
        units = set()
        for measurement in measurements.values():
            units.update(k[2] for k in measurement)
        if len(units) > 1:
            raise ValueError(
                "Multiple time units in a single plot: %r" % units)


def _plot_group(group_name, measurements):
    """
    Draws one group into a new figure, one line per measurement series,
    and returns that figure.
    """
    figure = plt.figure()
    log_x, log_y = False, False
    for line, values in measurements.items():
        # is_exponential_growth assumes increasing values, so order by x
        x, y, _ = zip(*sorted(values, key=lambda point: point[0]))
        log_x |= is_exponential_growth(x)
        log_y |= is_exponential_growth(y)
        plt.plot(x, y, label=line)
    if log_x:
        plt.xscale("log")
    if log_y:
        plt.yscale("log")
    plt.title(group_name)
    plt.legend()
    plt.grid()
    return figure


def main():
    """
    Script entry point: reads the benchmark JSON, groups the benchmarks
    and plots every group.

    Each plot is saved into the --output-path folder when that argument is
    given, otherwise it is displayed in the GUI.

    Raises:
        ValueError: if a plot would mix time units, or if a benchmark
            argument can not be converted to a number.
    """
    args = parse_args()
    data = _load_data(args.input_file)

    benchmarks_groups = _group_benchmarks(_collect_measurements(data))
    _validate_time_units(benchmarks_groups)

    if args.output_path:
        # exist_ok makes a separate existence check unnecessary
        os.makedirs(args.output_path, exist_ok=True)

    # plot all groups
    for group_name, measurements in benchmarks_groups.items():
        figure = _plot_group(group_name, measurements)
        if args.output_path:
            plt.savefig(os.path.join(
                args.output_path, group_name + "." + args.output_type))
            # release the figure, otherwise they pile up over all groups
            plt.close(figure)
        else:
            plt.show()


# Run the plotter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
